{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiments from the paper *Deep Interest Evolution Network for Click-Through Rate Prediction*\n",
    "\n",
    "## how to run\n",
    "\n",
    "1. Run `prepare_neg.ipynb` first. This notebook then reads `./local_train.csv` and `./local_test.csv` from the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: hyper-parameters and switches read by the cells below.\n",
    "SEQ_MAX_LEN = 100 # maximum sequence length\n",
    "# Batch size handed to create_dataloader_fn.\n",
    "BATCH_SIZE = 128\n",
    "# Embedding size handed to every model constructor.\n",
    "EMBEDDING_DIM = 18\n",
    "# Hidden layer sizes and dropout handed to every model constructor.\n",
    "DNN_HIDDEN_SIZE = [200, 80]\n",
    "DNN_DROPOUT = 0.0\n",
    "# When True, both data frames are subsampled to 1000 rows before training.\n",
    "TEST_RUN = False\n",
    "# Passed as the first argument of fit() (number of training epochs).\n",
    "EPOCH = 2\n",
    "# Seed applied to random, NumPy and torch.\n",
    "SEED = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import itertools\n",
    "from collections import Counter, OrderedDict\n",
    "\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "from prediction_flow.features import Number, Category, Sequence, Features\n",
    "from prediction_flow.transformers.column import (\n",
    "    StandardScaler, CategoryEncoder, SequenceEncoder)\n",
    "\n",
    "from prediction_flow.pytorch.data import Dataset\n",
    "from prediction_flow.pytorch import WideDeep, DeepFM, DNN, DIN, DIEN, AttentionGroup\n",
    "\n",
    "from prediction_flow.pytorch.functions import fit, predict, create_dataloader_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f6631bf0ef0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed every RNG in use (Python, NumPy, torch) with the same configured value.\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "# Last expression of the cell: its return value (a torch Generator) is what gets displayed.\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both inputs are tab-separated files read from the working directory.\n",
    "train_df = pd.read_csv(\n",
    "    \"./local_train.csv\", sep='\\t')\n",
    "\n",
    "# NOTE: the file named local_test.csv is used as the validation set in this notebook.\n",
    "valid_df = pd.read_csv(\n",
    "    \"./local_test.csv\", sep='\\t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test mode: shrink both frames so the whole notebook runs quickly.\n",
    "# random_state=SEED makes the subsample independent of the global NumPy RNG state\n",
    "# (so re-running or reordering cells does not change it), and min(...) keeps the\n",
    "# cell working on inputs that have fewer than 1000 rows.\n",
    "if TEST_RUN:\n",
    "    train_df = train_df.sample(min(1000, len(train_df)), random_state=SEED)\n",
    "    valid_df = valid_df.sample(min(1000, len(valid_df)), random_state=SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>uid</th>\n",
       "      <th>mid</th>\n",
       "      <th>cat</th>\n",
       "      <th>hist_mids</th>\n",
       "      <th>hist_cats</th>\n",
       "      <th>neg_hist_mids</th>\n",
       "      <th>neg_hist_cats</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>AZPJ9LUT0FEPY</td>\n",
       "      <td>B00AMNNTIA</td>\n",
       "      <td>Literature &amp; Fiction</td>\n",
       "      <td>0307744434\u00020062248391\u00020470530707\u00020978924622\u000215...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "      <td>1449710247\u00020810984164\u00020615633129\u00020962121940\u0002B0...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Literature &amp; Fiction</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>AZPJ9LUT0FEPY</td>\n",
       "      <td>0800731603</td>\n",
       "      <td>Books</td>\n",
       "      <td>0307744434\u00020062248391\u00020470530707\u00020978924622\u000215...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "      <td>0141017619\u00020736921680\u00020425258203\u0002160140462X\u000214...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>A2NRV79GKAU726</td>\n",
       "      <td>B003NNV10O</td>\n",
       "      <td>Russian</td>\n",
       "      <td>0814472869\u00020071462074\u00021583942300\u00020812538366\u0002B0...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Baking\u0002Books\u0002Books</td>\n",
       "      <td>051513287X\u00020231124694\u00021442409142\u00021118388461\u000219...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>A2NRV79GKAU726</td>\n",
       "      <td>B000UWJ91O</td>\n",
       "      <td>Books</td>\n",
       "      <td>0814472869\u00020071462074\u00021583942300\u00020812538366\u0002B0...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Baking\u0002Books\u0002Books</td>\n",
       "      <td>B00DNLWQ00\u00021250007119\u00020393077489\u00021591861179\u000202...</td>\n",
       "      <td>War\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>A2GEQVDX2LL4V3</td>\n",
       "      <td>0321334094</td>\n",
       "      <td>Books</td>\n",
       "      <td>0743596870\u00020374280991\u00021439140634\u00020976475731</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books</td>\n",
       "      <td>0764201816\u00020307265757\u00021466367741\u00020879462809</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label             uid         mid                   cat  \\\n",
       "0      0   AZPJ9LUT0FEPY  B00AMNNTIA  Literature & Fiction   \n",
       "1      1   AZPJ9LUT0FEPY  0800731603                 Books   \n",
       "2      0  A2NRV79GKAU726  B003NNV10O               Russian   \n",
       "3      1  A2NRV79GKAU726  B000UWJ91O                 Books   \n",
       "4      0  A2GEQVDX2LL4V3  0321334094                 Books   \n",
       "\n",
       "                                           hist_mids  \\\n",
       "0  0307744434\u00020062248391\u00020470530707\u00020978924622\u000215...   \n",
       "1  0307744434\u00020062248391\u00020470530707\u00020978924622\u000215...   \n",
       "2  0814472869\u00020071462074\u00021583942300\u00020812538366\u0002B0...   \n",
       "3  0814472869\u00020071462074\u00021583942300\u00020812538366\u0002B0...   \n",
       "4        0743596870\u00020374280991\u00021439140634\u00020976475731   \n",
       "\n",
       "                                    hist_cats  \\\n",
       "0               Books\u0002Books\u0002Books\u0002Books\u0002Books   \n",
       "1               Books\u0002Books\u0002Books\u0002Books\u0002Books   \n",
       "2  Books\u0002Books\u0002Books\u0002Books\u0002Baking\u0002Books\u0002Books   \n",
       "3  Books\u0002Books\u0002Books\u0002Books\u0002Baking\u0002Books\u0002Books   \n",
       "4                     Books\u0002Books\u0002Books\u0002Books   \n",
       "\n",
       "                                       neg_hist_mids  \\\n",
       "0  1449710247\u00020810984164\u00020615633129\u00020962121940\u0002B0...   \n",
       "1  0141017619\u00020736921680\u00020425258203\u0002160140462X\u000214...   \n",
       "2  051513287X\u00020231124694\u00021442409142\u00021118388461\u000219...   \n",
       "3  B00DNLWQ00\u00021250007119\u00020393077489\u00021591861179\u000202...   \n",
       "4        0764201816\u00020307265757\u00021466367741\u00020879462809   \n",
       "\n",
       "                                  neg_hist_cats  \n",
       "0  Books\u0002Books\u0002Books\u0002Books\u0002Literature & Fiction  \n",
       "1                 Books\u0002Books\u0002Books\u0002Books\u0002Books  \n",
       "2     Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books  \n",
       "3       War\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books\u0002Books  \n",
       "4                       Books\u0002Books\u0002Books\u0002Books  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>uid</th>\n",
       "      <th>mid</th>\n",
       "      <th>cat</th>\n",
       "      <th>hist_mids</th>\n",
       "      <th>hist_cats</th>\n",
       "      <th>neg_hist_mids</th>\n",
       "      <th>neg_hist_cats</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>A3BI7R43VUZ1TY</td>\n",
       "      <td>B00JNHU0T2</td>\n",
       "      <td>Literature &amp; Fiction</td>\n",
       "      <td>0989464105\u0002B00B01691C\u00021477809732\u00021608442845</td>\n",
       "      <td>Books\u0002Literature &amp; Fiction\u0002Books\u0002Books</td>\n",
       "      <td>1440500177\u0002B00JRENU3Y\u00020802118615\u00020007285248</td>\n",
       "      <td>Books\u0002Literature &amp; Fiction\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>A3BI7R43VUZ1TY</td>\n",
       "      <td>0989464121</td>\n",
       "      <td>Books</td>\n",
       "      <td>0989464105\u0002B00B01691C\u00021477809732\u00021608442845</td>\n",
       "      <td>Books\u0002Literature &amp; Fiction\u0002Books\u0002Books</td>\n",
       "      <td>B00KFVKIL0\u0002B00KPBM8TA\u00021560850116\u00021599901218</td>\n",
       "      <td>Literature &amp; Fiction\u0002Herbal Remedies\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>A2Z3AHJPXG3ZNP</td>\n",
       "      <td>B0072YSPJ0</td>\n",
       "      <td>Literature &amp; Fiction</td>\n",
       "      <td>1478310960\u00021492231452\u00021477603425\u0002B00FRKLA6Q</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Urban</td>\n",
       "      <td>1582702233\u00021439182450\u00020373296762\u00029963616089</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>A2Z3AHJPXG3ZNP</td>\n",
       "      <td>B00G4I4I5U</td>\n",
       "      <td>Urban</td>\n",
       "      <td>1478310960\u00021492231452\u00021477603425\u0002B00FRKLA6Q</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Urban</td>\n",
       "      <td>B00CH09DKY\u0002B007MSBKES\u00020679745890\u0002B00EZYZ7VO</td>\n",
       "      <td>Dogs\u0002Action &amp; Adventure\u0002Books\u0002Marketing</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>A2KDDPJUNWC5CA</td>\n",
       "      <td>0316228532</td>\n",
       "      <td>Books</td>\n",
       "      <td>0141326085\u0002031026622X\u00020316077046\u00020988649179\u000214...</td>\n",
       "      <td>Books\u0002Books\u0002Books\u0002Books\u0002Books</td>\n",
       "      <td>0435550268\u00020778313271\u0002B00DOM0YT8\u00021494965585\u0002B0...</td>\n",
       "      <td>Books\u0002Books\u0002Coming of Age\u0002Books\u0002Flower Arranging</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label             uid         mid                   cat  \\\n",
       "0      0  A3BI7R43VUZ1TY  B00JNHU0T2  Literature & Fiction   \n",
       "1      1  A3BI7R43VUZ1TY  0989464121                 Books   \n",
       "2      0  A2Z3AHJPXG3ZNP  B0072YSPJ0  Literature & Fiction   \n",
       "3      1  A2Z3AHJPXG3ZNP  B00G4I4I5U                 Urban   \n",
       "4      0  A2KDDPJUNWC5CA  0316228532                 Books   \n",
       "\n",
       "                                           hist_mids  \\\n",
       "0        0989464105\u0002B00B01691C\u00021477809732\u00021608442845   \n",
       "1        0989464105\u0002B00B01691C\u00021477809732\u00021608442845   \n",
       "2        1478310960\u00021492231452\u00021477603425\u0002B00FRKLA6Q   \n",
       "3        1478310960\u00021492231452\u00021477603425\u0002B00FRKLA6Q   \n",
       "4  0141326085\u0002031026622X\u00020316077046\u00020988649179\u000214...   \n",
       "\n",
       "                                hist_cats  \\\n",
       "0  Books\u0002Literature & Fiction\u0002Books\u0002Books   \n",
       "1  Books\u0002Literature & Fiction\u0002Books\u0002Books   \n",
       "2                 Books\u0002Books\u0002Books\u0002Urban   \n",
       "3                 Books\u0002Books\u0002Books\u0002Urban   \n",
       "4           Books\u0002Books\u0002Books\u0002Books\u0002Books   \n",
       "\n",
       "                                       neg_hist_mids  \\\n",
       "0        1440500177\u0002B00JRENU3Y\u00020802118615\u00020007285248   \n",
       "1        B00KFVKIL0\u0002B00KPBM8TA\u00021560850116\u00021599901218   \n",
       "2        1582702233\u00021439182450\u00020373296762\u00029963616089   \n",
       "3        B00CH09DKY\u0002B007MSBKES\u00020679745890\u0002B00EZYZ7VO   \n",
       "4  0435550268\u00020778313271\u0002B00DOM0YT8\u00021494965585\u0002B0...   \n",
       "\n",
       "                                      neg_hist_cats  \n",
       "0            Books\u0002Literature & Fiction\u0002Books\u0002Books  \n",
       "1  Literature & Fiction\u0002Herbal Remedies\u0002Books\u0002Books  \n",
       "2                           Books\u0002Books\u0002Books\u0002Books  \n",
       "3           Dogs\u0002Action & Adventure\u0002Books\u0002Marketing  \n",
       "4  Books\u0002Books\u0002Coming of Age\u0002Books\u0002Flower Arranging  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EDA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scale_eda(df):\n",
    "    \"\"\"Print basic size statistics of a sample frame.\n",
    "\n",
    "    Prints, in order: the frame's shape, the number of distinct ``uid``\n",
    "    values, the number of distinct ``mid`` values, and the per-``label``\n",
    "    count of ``uid`` entries.\n",
    "    \"\"\"\n",
    "    print(df.shape)\n",
    "    for column in ('uid', 'mid'):\n",
    "        print(df[column].nunique())\n",
    "    label_counts = df.groupby('label', as_index=False)['uid'].count()\n",
    "    print(label_counts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1086120, 8)\n",
      "543060\n",
      "261895\n",
      "   label     uid\n",
      "0      0  543060\n",
      "1      1  543060\n",
      "(121216, 8)\n",
      "60608\n",
      "75053\n",
      "   label    uid\n",
      "0      0  60608\n",
      "1      1  60608\n"
     ]
    }
   ],
   "source": [
    "scale_eda(train_df)\n",
    "scale_eda(valid_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['0307744434', '0062248391', '0470530707', '0978924622', '1590516400']"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at one raw history: item ids are joined with the \\x02 control character.\n",
    "# Read the cell through the named column instead of train_df.values[0][4], which\n",
    "# converts the whole frame to an object array just to fetch a single value.\n",
    "train_df['hist_mids'].iloc[0].split('\\x02')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**This data set is well balanced: each user has two samples, one positive and one negative.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_cats = Counter(train_df.cat.values.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count how often each category occurs inside the training histories.\n",
    "unique_cats_in_hist = Counter(\n",
    "    token for hist in train_df.hist_cats for token in hist.split(\"\\x02\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1459 1600 1459\n"
     ]
    }
   ],
   "source": [
    "print(len(unique_cats), len(unique_cats_in_hist),\n",
    "      len(np.intersect1d(list(unique_cats.keys()), list(unique_cats_in_hist.keys()))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**All target categories also appear in the history categories.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_mids = Counter(train_df.mid.values.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count how often each item id occurs inside the training histories.\n",
    "unique_mids_in_hist = Counter(\n",
    "    token for hist in train_df.hist_mids for token in hist.split(\"\\x02\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "261895 367788 261701\n"
     ]
    }
   ],
   "source": [
    "print(len(unique_mids), len(unique_mids_in_hist),\n",
    "      len(np.intersect1d(list(unique_mids.keys()), list(unique_mids_in_hist.keys()))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Most target mids also appear in the history mids.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 86.27769709405354% mid overlap between train and valid\n"
     ]
    }
   ],
   "source": [
    "print(\"There are {}% mid overlap between train and valid\".format(\n",
    "    100 * len(np.intersect1d(train_df.mid.unique(), valid_df.mid.unique())) / len(valid_df.mid.unique())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 97.91208791208791% mid overlap between train and valid\n"
     ]
    }
   ],
   "source": [
    "# Share of validation categories that also occur in the training targets.\n",
    "# (The message previously said \"mid\" although the cell measures the cat column.)\n",
    "print(\"There are {}% cat overlap between train and valid\".format(\n",
    "    100 * len(np.intersect1d(train_df.cat.unique(), valid_df.cat.unique())) / len(valid_df.cat.unique())))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# define features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_enc = SequenceEncoder(sep=\"\\x02\", min_cnt=1, max_len=SEQ_MAX_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<prediction_flow.transformers.column.sequence_encoder.SequenceEncoder at 0x7f6584b6cf28>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the category vocabulary from the target column as well as the histories,\n",
    "# the same way mid_enc is fitted, so every training target category gets an index\n",
    "# even when the frame has been subsampled (TEST_RUN).\n",
    "cat_enc.fit(np.vstack([train_df.cat.values, train_df.hist_cats.values]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_word2idx, cat_idx2word = cat_enc.word2idx, cat_enc.idx2word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1602\n"
     ]
    }
   ],
   "source": [
    "print(len(cat_word2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "mid_enc = SequenceEncoder(sep=\"\\x02\", min_cnt=1, max_len=SEQ_MAX_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<prediction_flow.transformers.column.sequence_encoder.SequenceEncoder at 0x7f658463cd30>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mid_enc.fit(np.vstack([train_df.mid.values, train_df.hist_mids.values]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "mid_word2idx, mid_idx2word = mid_enc.word2idx, mid_enc.idx2word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "367984\n"
     ]
    }
   ],
   "source": [
    "print(len(mid_word2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No numeric features are used in this experiment.\n",
    "number_features = []\n",
    "\n",
    "# Target columns: they reuse the vocabularies fitted above, and each uses the same\n",
    "# embedding_name as the matching history sequence defined further down.\n",
    "category_features = [\n",
    "    Category('mid',\n",
    "             CategoryEncoder(min_cnt=1, word2idx=mid_word2idx, idx2word=mid_idx2word),\n",
    "             embedding_name='mid'),\n",
    "    Category('cat',\n",
    "             CategoryEncoder(min_cnt=1, word2idx=cat_word2idx, idx2word=cat_idx2word),\n",
    "             embedding_name='cat'),\n",
    "]\n",
    "\n",
    "# History columns: \\x02-separated sequences limited by SEQ_MAX_LEN.\n",
    "sequence_features = [\n",
    "    Sequence('hist_mids',\n",
    "             SequenceEncoder(sep=\"\\x02\", min_cnt=1, max_len=SEQ_MAX_LEN,\n",
    "                             word2idx=mid_word2idx, idx2word=mid_idx2word),\n",
    "             embedding_name='mid'),\n",
    "    Sequence('hist_cats',\n",
    "             SequenceEncoder(sep=\"\\x02\", min_cnt=1, max_len=SEQ_MAX_LEN,\n",
    "                             word2idx=cat_word2idx, idx2word=cat_idx2word),\n",
    "             embedding_name='cat')\n",
    "]\n",
    "\n",
    "# Positional arguments after the feature lists: batch size, train frame, label\n",
    "# column, valid frame, then 4 (presumably the number of loader workers --\n",
    "# TODO confirm against create_dataloader_fn's signature).\n",
    "features, train_loader, valid_loader = create_dataloader_fn(\n",
    "    number_features, category_features, sequence_features, BATCH_SIZE, train_df, 'label', valid_df, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(model, df, dataloader):\n",
    "    \"\"\"Return the ROC AUC of ``model`` on ``dataloader``.\n",
    "\n",
    "    Predictions come from ``predict(model, dataloader)`` and are flattened\n",
    "    before being scored against ``df['label']``; the two are paired by\n",
    "    position, so ``df`` must list the labels in the order the dataloader\n",
    "    yields its rows.\n",
    "    \"\"\"\n",
    "    flat_preds = predict(model, dataloader).ravel()\n",
    "    return roc_auc_score(df['label'], flat_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _attention_groups(**extra):\n",
    "    \"\"\"Build the one-group attention config shared by DIN and the DIEN variants.\n",
    "\n",
    "    The group pairs each target field with its history sequence\n",
    "    (mid/hist_mids and cat/hist_cats). Keyword arguments such as\n",
    "    ``gru_type`` are forwarded to ``AttentionGroup`` unchanged; when none\n",
    "    are given, ``AttentionGroup``'s own defaults apply.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        AttentionGroup(\n",
    "            name='group1',\n",
    "            pairs=[{'ad': 'mid', 'pos_hist': 'hist_mids'},\n",
    "                   {'ad': 'cat', 'pos_hist': 'hist_cats'}],\n",
    "            hidden_layers=[80, 40], att_dropout=0.0, **extra)]\n",
    "\n",
    "\n",
    "# The five configurations differ only in the GRU flavour; DIN passes none.\n",
    "din_attention_groups = _attention_groups()\n",
    "gru_attention_groups = _attention_groups(gru_type='GRU')\n",
    "aigru_attention_groups = _attention_groups(gru_type='AIGRU')\n",
    "agru_attention_groups = _attention_groups(gru_type='AGRU')\n",
    "augru_attention_groups = _attention_groups(gru_type='AUGRU')\n",
    "\n",
    "# (name, model) pairs consumed by run(); models are constructed in this order.\n",
    "models = [\n",
    "    ('DNN', DNN(features, 2, EMBEDDING_DIM, DNN_HIDDEN_SIZE,\n",
    "        final_activation='sigmoid', dropout=DNN_DROPOUT)),\n",
    "\n",
    "    ('WideDeep', WideDeep(features,\n",
    "             wide_features=['mid', 'hist_mids', 'cat', 'hist_cats'],\n",
    "             deep_features=['mid', 'hist_mids', 'cat', 'hist_cats'],\n",
    "             cross_features=[('mid', 'hist_mids'), ('cat', 'hist_cats')],\n",
    "             num_classes=2, embedding_size=EMBEDDING_DIM, hidden_layers=DNN_HIDDEN_SIZE,\n",
    "             final_activation='sigmoid', dropout=DNN_DROPOUT)),\n",
    "\n",
    "    ('DeepFM', DeepFM(features, 2, EMBEDDING_DIM, DNN_HIDDEN_SIZE,\n",
    "           final_activation='sigmoid', dropout=DNN_DROPOUT)),\n",
    "\n",
    "    ('DIN', DIN(features, din_attention_groups, 2, EMBEDDING_DIM, DNN_HIDDEN_SIZE,\n",
    "        final_activation='sigmoid', dropout=DNN_DROPOUT)),\n",
    "\n",
    "    ('DIEN_gru', DIEN(features, gru_attention_groups, 2, EMBEDDING_DIM, DNN_HIDDEN_SIZE,\n",
    "         final_activation='sigmoid', dropout=DNN_DROPOUT)),\n",
    "\n",
    "    ('DIEN_aigru', DIEN(features, aigru_attention_groups, 2, EMBEDDING_DIM, DNN_HIDDEN_SIZE,\n",
    "         final_activation='sigmoid', dropout=DNN_DROPOUT)),\n",
    "\n",
    "    ('DIEN_agru', DIEN(features, agru_attention_groups, 2, EMBEDDING_DIM, DNN_HIDDEN_SIZE,\n",
    "         final_activation='sigmoid', dropout=DNN_DROPOUT)),\n",
    "\n",
    "    ('DIEN_augru', DIEN(features, augru_attention_groups, 2, EMBEDDING_DIM, DNN_HIDDEN_SIZE,\n",
    "         final_activation='sigmoid', dropout=DNN_DROPOUT))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run(models, lr=0.001, weight_decay=1e-5):\n",
    "    \"\"\"Train and score every model in ``models``.\n",
    "\n",
    "    Reads the notebook globals ``EPOCH``, ``train_loader``, ``valid_loader``\n",
    "    and ``valid_df``.\n",
    "\n",
    "    Args:\n",
    "        models: iterable of ``(name, model)`` pairs.\n",
    "        lr: Adam learning rate; the default keeps the original 0.001.\n",
    "        weight_decay: Adam weight decay; the default keeps the original 1e-5.\n",
    "\n",
    "    Returns:\n",
    "        ``(scores, model_loss_curves)``: two ``OrderedDict`` objects keyed by\n",
    "        model name, holding the value returned by ``evaluation`` and the\n",
    "        value returned by ``fit`` respectively.\n",
    "    \"\"\"\n",
    "    scores = OrderedDict()\n",
    "    model_loss_curves = OrderedDict()\n",
    "    for model_name, model in models:\n",
    "        print(model_name)\n",
    "        # A fresh loss and optimizer per model, so nothing carries over between models.\n",
    "        loss_func = nn.BCELoss()\n",
    "        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "        losses = fit(EPOCH, model, loss_func, optimizer,\n",
    "            train_loader, valid_loader, notebook=True, auxiliary_loss_rate=1)\n",
    "        scores[model_name] = evaluation(model, valid_df, valid_loader)\n",
    "        model_loss_curves[model_name] = losses\n",
    "    return scores, model_loss_curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DNN\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7fbb7f9f597b47f69b13b9c5998d5e77",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd2587dc82474e80a63bfe510c84b490",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1d093792aaf487fbc5847c163eb213f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n",
      "WideDeep\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02c3966bdb354e919bb2952192667c79",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9581d10bd204e6dbe5a26a584db04f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8434713aa4c841f9b65613df683130e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n",
      "DeepFM\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1dec42dae7e0496a889c2f6591a10cb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e847e08baae64a39bc1616962adf4033",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a06bf1b36a34409180f2c5cbea4f8646",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n",
      "DIN\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6894f61520648b38caff6a688114ff3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4fda07da268b44be9e29e716d6a6bf5f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7b1dccbbad7a48c2a696fca9baed1018",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n",
      "DIEN_gru\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8175a1831094942a5de454e05d91270",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da6461d19bc9409e8beeade92ab24b33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f945b88d677e4180aecb8decb9b51b39",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n",
      "DIEN_aigru\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2aba28aff938494dbd712b5e65d9b9f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00b085c979134a6f9da3d90c728c22ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e4b48c48a05458d82a66b80dc27b4ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n",
      "DIEN_agru\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "055fe36d1823434b91032e8f058637ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0a16eec4e2fa4121a8089c85c41112b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c845278be4c845188ab8a04b60e7d072",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n",
      "DIEN_augru\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dfa1d0f1a97648c0b9cc8e91b46809d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9da0519e5fde48acad9c0352178455a3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f1f332353c19486a937d3aeb3cdcd999",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n"
     ]
    }
   ],
   "source": [
    "scores1, model_loss_curves1 = run(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('DNN', 0.7414861590544137), ('WideDeep', 0.7366498651126783), ('DeepFM', 0.7493579010827383), ('DIN', 0.7751056232180309), ('DIEN_gru', 0.7771685541807118), ('DIEN_aigru', 0.7738427067568526), ('DIEN_agru', 0.7807896280337772), ('DIEN_augru', 0.7816471855600924)])\n"
     ]
    }
   ],
   "source": [
    "print(scores1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('DNN', [{'train_loss': 0.6545418111215432, 'valid_loss': 0.6306317368681604, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}, {'train_loss': 0.6006484859827803, 'valid_loss': 0.5946546539963232, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}]), ('WideDeep', [{'train_loss': 0.6992310519463434, 'valid_loss': 0.6448254635618511, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}, {'train_loss': 0.6043528134344365, 'valid_loss': 0.6019523961679991, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}]), ('DeepFM', [{'train_loss': 0.9217421180679916, 'valid_loss': 0.6287330887511505, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}, {'train_loss': 0.5938254600035651, 'valid_loss': 0.5880784660483367, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}]), ('DIN', [{'train_loss': 0.6470800264317711, 'valid_loss': 0.6101615355284199, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}, {'train_loss': 0.5630607244901257, 'valid_loss': 0.5642400029279607, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}]), ('DIEN_gru', [{'train_loss': 0.6420232596232783, 'valid_loss': 0.60818177304273, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}, {'train_loss': 0.5592777577072775, 'valid_loss': 0.5618925285314175, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}]), ('DIEN_aigru', [{'train_loss': 0.64573271304629, 'valid_loss': 0.6104722710949825, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}, {'train_loss': 0.5608341148719748, 'valid_loss': 0.5651352940227569, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}]), ('DIEN_agru', [{'train_loss': 0.6416976842260987, 'valid_loss': 0.6146120481909014, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}, {'train_loss': 0.5551212878369666, 'valid_loss': 0.5583103171057032, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}]), ('DIEN_augru', [{'train_loss': 0.6400591503357239, 'valid_loss': 0.5996584904231639, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}, 
{'train_loss': 0.5508024052308748, 'valid_loss': 0.5573866153342417, 'train_auxiliary_loss': 0, 'valid_auxiliary_loss': 0}])])\n"
     ]
    }
   ],
   "source": [
    "print(model_loss_curves1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "number_features = []\n",
    "\n",
    "category_features = [\n",
    "    Category('mid',\n",
    "             CategoryEncoder(min_cnt=1, word2idx=mid_word2idx, idx2word=mid_idx2word),\n",
    "             embedding_name='mid'),\n",
    "    Category('cat',\n",
    "             CategoryEncoder(min_cnt=1, word2idx=cat_word2idx, idx2word=cat_idx2word),\n",
    "             embedding_name='cat'),\n",
    "]\n",
    "\n",
    "sequence_features = [\n",
    "    Sequence('hist_mids',\n",
    "             SequenceEncoder(sep=\"\u0002\", min_cnt=1, max_len=SEQ_MAX_LEN,\n",
    "                             word2idx=mid_word2idx, idx2word=mid_idx2word),\n",
    "             embedding_name='mid'),\n",
    "    Sequence('hist_cats',\n",
    "             SequenceEncoder(sep=\"\u0002\", min_cnt=1, max_len=SEQ_MAX_LEN,\n",
    "                             word2idx=cat_word2idx, idx2word=cat_idx2word),\n",
    "             embedding_name='cat'),\n",
    "    Sequence('neg_hist_mids',\n",
    "             SequenceEncoder(sep=\"\u0002\", min_cnt=1, max_len=SEQ_MAX_LEN,\n",
    "                             word2idx=mid_word2idx, idx2word=mid_idx2word),\n",
    "             embedding_name='mid'),\n",
    "    Sequence('neg_hist_cats',\n",
    "             SequenceEncoder(sep=\"\u0002\", min_cnt=1, max_len=SEQ_MAX_LEN,\n",
    "                             word2idx=cat_word2idx, idx2word=cat_idx2word),\n",
    "             embedding_name='cat')\n",
    "]\n",
    "\n",
    "features, train_loader, valid_loader = create_dataloader_fn(\n",
    "    number_features, category_features, sequence_features, BATCH_SIZE, train_df, 'label', valid_df, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "augru_attention_groups_with_neg = [\n",
    "    AttentionGroup(\n",
    "        name='group1',\n",
    "        pairs=[{'ad': 'mid', 'pos_hist': 'hist_mids', 'neg_hist': 'neg_hist_mids'},\n",
    "               {'ad': 'cat', 'pos_hist': 'hist_cats', 'neg_hist': 'neg_hist_cats'}],\n",
    "        hidden_layers=[80, 40], att_dropout=0.0, gru_type='AUGRU')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = [\n",
    "    ('DIEN', DIEN(features, augru_attention_groups_with_neg, 2, EMBEDDING_DIM, DNN_HIDDEN_SIZE,\n",
    "         final_activation='sigmoid', dropout=DNN_DROPOUT, use_negsampling=True))\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DIEN\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c09ca2a683b340229c9fc53a92de0996",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='training routine', max=2, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da906c88cc91459ab897e4d38b079220",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='train', max=8486, style=ProgressStyle(description_width='init…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f5a3deff836e4c3aaaa3b9e0cb1d99c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='valid', max=947, style=ProgressStyle(description_width='initi…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available, transfer model to GPU.\n"
     ]
    }
   ],
   "source": [
    "scores2, model_loss_curves2 = run(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('DIEN', 0.7780708151545855)])\n"
     ]
    }
   ],
   "source": [
    "print(scores2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('DIEN', [{'train_loss': 0.6435200786944595, 'valid_loss': 0.6058936769884274, 'train_auxiliary_loss': 0.9507109770411276, 'valid_auxiliary_loss': 0.9529147756766867}, {'train_loss': 0.5532998496347418, 'valid_loss': 0.5623541603055653, 'train_auxiliary_loss': 0.9504248710376004, 'valid_auxiliary_loss': 0.9479448593279634}])])\n"
     ]
    }
   ],
   "source": [
    "print(model_loss_curves2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = scores1.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores['DIEN'] = scores2['DIEN']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('DNN', 0.7414861590544137),\n",
       "             ('WideDeep', 0.7366498651126783),\n",
       "             ('DeepFM', 0.7493579010827383),\n",
       "             ('DIN', 0.7751056232180309),\n",
       "             ('DIEN_gru', 0.7771685541807118),\n",
       "             ('DIEN_aigru', 0.7738427067568526),\n",
       "             ('DIEN_agru', 0.7807896280337772),\n",
       "             ('DIEN_augru', 0.7816471855600924),\n",
       "             ('DIEN', 0.7780708151545855)])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "def autolabel(rects):\n",
    "    \"\"\"\n",
    "    Attach a text label above each bar displaying its height\n",
    "    \"\"\"\n",
    "    for rect in rects:\n",
    "        height = rect.get_height()\n",
    "        ax.text(rect.get_x() + rect.get_width()/2., 1.05*height,\n",
    "                '%f' % height,\n",
    "                ha='center', va='bottom')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAE0CAYAAABkXuSSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO3deXxU5dn/8e8lCFZxQYwtJOwEIYEQICguTxWtCi6gbZTFKuKCvyoWFBfUp7g8tupj1aosLn20LpVFbTVVRK2IWkVZZFF2BISA1YgsigqEXL8/ZjKdTCYhgZycIXzer1dezjnnnnOuexIP3znLfczdBQAAgNq1X9gFAAAA7IsIYQAAACEghAEAAISAEAYAABACQhgAAEAICGEAAAAhCDSEmVlvM1tqZivMbFSS5S3N7C0zW2Bm080sI8h6AAAAUoUFNU6YmdWTtEzSqZIKJc2SNNDdF8W1eV7SK+7+lJmdLGmIu18YSEEAAAApJMgjYUdLWuHuK919u6SJkvoltMmSNC36+u0kywEAAOqkIENYuqS1cdOF0Xnx5kv6ZfT1uZIONrMmAdYEAACQEuqHvP3rJI0xs4slvStpnaSdiY3MbKikoZJ00EEHde/QoUNt1ggAALBb5syZ87W7pyVbFmQIWyepedx0RnRejLuvV/RImJk1kvQrd9+UuCJ3f0zSY5KUl5fns2fPDqpmAACAGmNmn1e0LMjTkbMkZZpZazNrIGmApIKEwo4ws9IabpL0RID1AAAApIzAQpi7F0saJul1SYslTXb3hWZ2h5n1jTY7SdJSM1sm6aeSfh9UPQAAAKkksCEqgsLpSAAAsLcwsznunpdsGSPmAwAAhIAQBgAAEAJCGAAAQAgIYQAAACEghAEAAISAEAYAABACQhgAAEAICGEAAAAhIIQBAACEgBAGAAjF1KlTddRRR6ldu3a6++67yy2/5pprlJubq9zcXLVv316HHXZYbNkNN9yg7OxsdezYUb/97W9V+vSXOXPmqHPnzmrXrl2Z+f3794+tq1WrVsrNzZUkbd++XUOGDFHnzp3VpUsXTZ8+PbaNitaVCv285ZZb1Lx5czVq1Kjc+iZPnqysrCxlZ2dr0KBBZZZt2bJFGRkZGjZsmCTp22+/jW07NzdXRxxxhEaMGLFb/cRucPe96qd79+4OANi7FRcXe5s2bfyzzz7zbdu2eU5Oji9cuLDC9g899JAPGTLE3d3ff/99P+6447y4uNiLi4u9Z8+e/vbbb7u7e48ePXzGjBleUlLivXv39ilTppRb17XXXuu33367u7uPGTPGL774Ynd3//LLL71bt26+c+fOKq8rrH7OmDHD169f7wcddFCZ9y9btsxzc3P9m2++ifUp3m9/+1sfOHCgX3XVVUm3361bN3/nnXeq3U9399dee83bt2/vbdu29bvuuqvc8hEjRniXLl28S5cunpmZ6Yceeqi7u0+bNi02v0uXLt6wYUP/+9//7u7u//znP71r167epUsXP/7443358uXu7v7jjz/6+eef723btvWjjz7aV61a5e7uzz77bJl1mZnPnTvX3d1vvvlmz8jIKPeZBU3SbK8g04Qeqqr7QwgDkgtiBzh48GBv1apVbFnpzmzx4sXes2dPb9Cggd97771VqqOkpMRvvvlmz8zM9A4dOviDDz6YMv084YQTYvObNm3q/fr122U/N27c6L/61a/8qKOO8g4dOvgHH3zg7u7z5s3znj17eqdOnfyss87yzZs371Y/67oPPvjATzvttNj0H/7wB//DH/5QYftjjz3W33jjjdh7u3Xr5t9//71v3brVu3fv7osWLfL169f7UUcdFXvPc88950OHDi2znpKSEs/IyPBly5a5u/uVV17pTz/9dGz5ySef7B999FGV1hVWP+MlBorrr7/eH3/88aTrnj17tvfv39+ffPLJpCFs6dKlnpGR4SUlJVXuX6k9
CZvxNmzY4I0bN/atW7e6u3tmZmasz2PHjvXBgwfHXl9xxRXu7j5hwgQ///zzy61rwYIF3qZNm9h0RcE1aJWFME5HAnXAzp07ddVVV+m1117TokWLNGHCBC1atKhMmwceeEDz5s3TvHnzdPXVV+uXv/ylJKlXr16x+dOmTdOBBx6o0047Lfa+e++9N7a89BTO4YcfroceekjXXXddlev4y1/+orVr12rJkiVavHixBgwYkDL9fO+992LLjj322Nh7KuqnJA0fPly9e/fWkiVLNH/+fHXs2FGSdNlll+nuu+/WJ598onPPPVf33ntvtfsp7f4prLfffrvM6aUDDjhAL730kiTp0ksvVZcuXZSTk6P8/Hx99913kqQ1a9aoV69e6tq1q3JycjRlypQy21qzZo0aNWqkP/7xj5KkH3/8UUcffbS6dOmi7Oxs3XrrrdXu37p169S8efPYdEZGhtatW5e07eeff65Vq1bp5JNPliQde+yx6tWrl5o2baqmTZvq9NNPV8eOHbVu3TplZGRUus733ntPP/3pT5WZmSlJ6tKliwoKClRcXKxVq1Zpzpw5Wrt2bZXWFVY/K7Ns2TItW7ZMxx9/vHr27KmpU6dKkkpKSjRy5MjY7zCZiRMnqn///jKz6nZTM2fOVLt27dSmTRs1aNBAAwYM0Msvv1xh+wkTJmjgwIHl5r/wwgvq06ePDjzwQEmSmWnLli2SpM2bN6tZs2aSpJdfflmDBw+WJOXn5+utt94qd7p4woQJZfYzPXv2VNOmTavdtyARwoA6IKgdYEWOPPJI9ejRQ/vvv3+V6xg/frxGjx6t/fbbL7aO6gq6n1u2bNG0adN0zjnnVNrPzZs3691339Wll14qSWrQoEEsBC1btkw///nPJUmnnnqqXnzxxWr3M6iw+cADD2j+/PlasGCBWrRooTFjxkiS7rzzTp1//vmaO3euJk6cqCuvvLLMtq699lr16dMnNt2wYUNNmzZN8+fP17x58zR16lR9+OGH1e5nVU2cOFH5+fmqV6+eJGnFihVavHixCgsLtW7dOk2bNk3vvfdeldaV+DdxySWXKCMjQ3l5eRoxYoSOO+642HZqW030s7i4WMuXL9f06dM1YcIEXX755dq0aZPGjRunM844o0ywTLb9ZP+/VMWehM3Kavjzn/8cq/uZZ57RqFGjym2vfv36OvTQQ7Vhw4Yy65o0adJu96e2EMKAOiCoHaAUuQA4JydH11xzjbZt27bbdXz22WeaNGmS8vLy1KdPHy1fvrzK/avK+hNVt5+S9NJLL+mUU07RIYccUmkdq1atUlpamoYMGaKuXbvqsssu09atWyVJ2dnZsWD4/PPPa+3atVXuX6mgwmZpv9xdP/zwQ+yIR0VHG6TIZ9K6dWtlZ2fH5plZ7ILwHTt2aMeOHdU+epKenl7msyksLFR6enrStom/r7///e/q2bOnGjVqpEaNGqlPnz6aMWOG0tPTVVhYWOE6i4uL9be//U39+/ePzatfv34s0L788svatGmT2rdvv8t1hdnPymRkZKhv377af//91bp1a7Vv317Lly/XjBkzNGbMGLVq1UrXXXednn766VigkaT58+eruLhY3bt3r3YfqysxbJb64osv9Mknn+j000+PzXvggQc0ZcoUFRYWasiQIbr22murtI2PPvpIBx54oDp16lSjtdc0Qhiwj6nODvCuu+7SkiVLNGvWLH3zzTe65557dnu727Zt0wEHHKDZs2fr8ssv1yWXXLLb66qK6vSzVEVhJlFxcbE+/vhj/eY3v9HcuXN10EEHxU4ZPvHEExo3bpy6d++ub7/9Vg0aNKh27UGGzSFDhuhnP/uZlixZoquvvlqSdNttt+nZZ59VRkaGzjjjDD388MOSpO+++0733HNP0tONO3fuVG5uro488kideuqpOuaYY6rVxx49emj58uVatWqVtm/frokTJ6pv377l2i1ZskQbN27UscceG5vXokULvfPOOyouLtaO
HTv0zjvvqGPHjmratKkOOeQQffjhh3J3Pf300+rXr1/sff/85z/VoUOHMkeDvv/++1iAfvPNN1W/fn1lZWXtcl1h9rMy55xzTuwOz6+//lrLli1TmzZt9Ne//lVr1qzR6tWr9cc//lEXXXRRmdPcVf3br8iehM1SkydP1rnnnhs78lxUVKT58+fH/rb69++vDz74oNz2iouLtXnzZjVp0mSX20g1hDCgDghiByhJTZs2lZmpYcOGGjJkiGbOnLnbdWRkZMROmZ177rlasGBB1TtYhfUnqk4/pcg/WDNnztSZZ565yzoyMjKUkZER+8chPz9fH3/8sSSpQ4cOeuONNzRnzhwNHDhQbdu2rXL/dkd1w+aTTz6p9evXq2PHjpo0aZKkyD/AF198sQoLCzVlyhRdeOGFKikp0W233aZrrrkm6TAI9erV07x581RYWKiZM2fq008/rVbd9evX15gxY2LXOZ1//vnKzs7W6NGjVVBQUKZ/AwYMKHOkLT8/X23bto0NK9GlSxedffbZkqRx48bpsssuU7t27dS2bdsyp1GT/U189dVX6tatmzp27Kh77rlHzzzzTGxZZesKu5833HCDMjIy9P333ysjI0O33XabJOn0009XkyZNlJWVpV69eunee+8tE04qMnny5D0KLXsSNkslBsHGjRtr8+bNWrZsmaRISC4NoX379tVTTz0lKXLE9+STT459diUlJZo8efJuXXda6yq6Yj9Vf7g7Eihvx44d3rp1a1+5cmXszqRPP/20XLvFixd7y5Ytk979dMwxx/i0adPKzFu/fr27R+4oGz58uN94441llt96661l7hqsrI4bb7zR/+///s/d3d9++23Py8tLmX66u48fP94vuuiipNtN7Kd75I7KJUuWxJZfd9117v6fIQF27tzpF154YazP1VGdO+pyc3P9/fffLzf/T3/6k19++eUVbuOdd97xM888093ds7KyfM2aNbFlrVu39i+//NJPOOEEb9mypbds2dIPPfRQb9y4sT/88MPl1nX77beX+3yw73n11Vc9MzPT27Rp43feeae7u//ud7/zl19+Odbm1ltvLbcfcXdftWqVN2vWLDY8SKm//e1v3qlTJ8/JyfETTzzRP/vsM3d3/+GHHzw/P9/btm3rPXr0iM13j+xfjjnmmHLbuP766z09Pd3NzNPT0/3WW2+tiW7vkhiiAqj7gtgB9urVyzt16uTZ2dl+wQUX+Lfffuvu7l988YWnp6f7wQcf7Iceeqinp6fHhmJIVod7ZEiHM844wzt16uQ9e/b0efPmpUw/3d1PPPFEf+2118rMq6yfc+fO9e7du3vnzp29X79+sXGZ/vSnP3lmZqZnZmb6jTfeuFu3+wcRNktKSmJjLJWUlPjIkSN95MiR7u7eu3dvf/LJJ93dfdGiRd60adNy64wPol999ZVv3LjR3d2///57P+GEE/wf//hHtfsJ7AsIYdjn7e7YUqU2b97s6enpZcbWmThxonfu3NmzsrL8hhtuiM0fP368d+rUKTa4YOlYOatWrfIDDjggtp3SMW7cI2MQderUyTt37uynn366FxUV1fRHgL1MTYfNnTt3+nHHHRcL1YMGDYoFyoULF/pxxx3nOTk53qVLF3/99dfLrTM+hM2fP99zc3O9c+fOnp2dHRv4FEB5lYUwiyzfe+Tl5fns2bPDLgN7kZ07d6p9+/Z68803lZGRoR49emjChAnKyspK2v7hhx/W3Llz9cQTT8TmDR8+XEVFRTr88MM1ZswYbdiwQV27dtWcOXOUlpamwYMH66KLLtIpp5yiLVu2xO5CKygo0Lhx4zR16lStXr1aZ511VrlrZ4qLi9WsWTMtWrRIRxxxhG644QYdeOCBsWs8AAB7LzOb4+55yZYFemG+mfU2s6VmtsLMRiVZ3sLM3jazuWa2wMzOCLIe7Jv29Hb/OXPm6MsvvywzgOnKlSuVmZmptLQ0SdIvfvGL2HhQ8cMbbN26dZe37pd+
I9q6davcXVu2bCkzRAAAoG4KLISZWT1JYyX1kZQlaaCZJR56+G9Jk929q6QBksYFVU917ckDV6XyD0mVKn/gqiS9+OKLMjOVHunbsWOHBg8erM6dO6tjx4666667Ym03bdqk/Px8dejQQR07dtzl2DH7sj253b+iUabbtWunpUuXavXq1SouLtZLL71U5q69sWPHqm3btrrhhhv00EMPxeavWrVKXbt21YknnhgbdHH//ffX+PHj1blz59gRsdJBQAEAdVeQR8KOlrTC3Ve6+3ZJEyUlDrLikkoPGxwqaX2A9VTZnoxWXep3v/tdbNTsUmeffXaFt/h/++23evDBB8uMtfP8889r27Zt+uSTTzRnzhw9+uijWr16taSKH5lSXbsbNj///HN169ZNubm5ys7O1iOPPBJ7z/bt2zV06FC1b99eHTp0KDNi+OTJk5WVlaXs7GwNGjQoNr9evXqx7cTf1rxq1Sodc8wxateunfr376/t27fvVj+rKvF2/4pGmW7cuLHGjx+v/v3767/+67/UqlWrMkMEXHXVVfrss890zz336M4775QUGe5hzZo1mjt3ru6//34NGjRIW7Zs0Y4dOzR+/HjNnTtX69evV05OTpnADQCooyq6WGxPfyTlS/pz3PSFksYktGkq6RNJhZI2Suq+q/XWxoX5e/LAVfddPyQ12cNDhw8f7q+88oqfeOKJPmvWLHePXKx91lln+Y4dO/zrr7/2zMxM37Bhg2/atMlbtWq1W3ddxduTB65u27bNf/zxR3d3//bbb71ly5a+bt06d3cfPXq033LLLe4euRi49CLzZcuWeW5ubuwustJb+Sv6TNzdzzvvPJ8wYYK7u19xxRU+bty4avdzT273HzRokDdv3txbtmzpTZo08YMPPjjphdCPPvqoX3/99eXm79y50w855JCk2yr9Xc+cOdNPPvnk2Px33nnH+/TpU+X+AQBSl1L4Ad4DJf3F3TMknSHpGTMrV5OZDTWz2WY2u6ioKPCigjh9VZmPP/5Ya9euLTdIZH5+vg466CA1bdpULVq00HXXXafDDz+80kemVMeeXCvVoEEDNWzYUFJkJPSSkpJYuyeeeEI33XSTJGm//fbTEUccIUl6/PHHddVVV6lx48aSdv3sQHfXtGnTlJ+fL0kaPHhw7EHE1bEngwhWNsr0V199JUnauHFjbFBHSWUex/Pqq6/GHhRcVFSknTt3SopcU7Z8+XK1adNG6enpWrRokUr/tuMHJAQA1F1BhrB1kprHTWdE58W7VNJkSXL3GZIOkHRE4orc/TF3z3P3vNILoVNFVU9fVaSkpETXXnut7rvvvnLLZs6cqXr16mn9+vVatWqV7rvvPq1cubLSR6ZUx54+GmXt2rXKyclR8+bNdeONN6pZs2batGmTpMjp2G7duum8887Tl19+KSnyYONly5bp+OOPV8+ePTV16tTYun788Ufl5eWpZ8+esaC1YcMGHXbYYapfv/4u66vMnoxYXZnhw4crKytLxx9/vEaNGqX27dtLksaMGaPs7Gzl5ubq/vvvj43q/O677yonJ0e5ubnKz8/XI488osMPP1zNmjXTrbfeqp///OfKycnRvHnzdPPNN1e7nwCAvUxFh8j29EdSfUkrJbWW1EDSfEnZCW1ek3Rx9HVHRa4Js8rWm2qnI3fn9FX8qbdNmzZ5kyZNYqNSN2zY0Js2beqzZs3yK6+80p9++ulY2yFDhvikSZP8iy++8JYtW8bmv/vuu37GGWdUu5/PP/+8X3rppbHpp59+OunpU3f3u+++24cNG5Z02bp167xHjx7+73//24uKilySP//88+7uft999/mvf/1rd3c/88wz/ZxzzvHt27f7ypUrPSMjIzbgY2Fhobu7f/bZZ96yZUtfsWKFFxUVedu2bWPbWbNmjWdnZ1e7nwAAhEVhnI5092JJwyS9LmmxIndBLjSzO8ys9FzQSEmXm9l8SROigSz0gcuCOn2VzKGHHqqvv/5a
q1ev1urVq9WzZ08VFBQoLy9PLVq00LRp0yRFhjr48MMP1aFDB/3sZz9T8+bNtXTpUknSW2+9VeGYV5WpiefwSVKzZs3UqVMnvffee2rSpIkOPPDA2I0K5513XuyZehkZGerbt6/2339/tW7dWu3bt4+duivdbps2bXTSSSdp7ty5atKkiTZt2qTi4uJd1gcAwN4m0GvC3H2Ku7d397bu/vvovNHuXhB9vcjdj3f3Lu6e6+5vBFlPVQV1+qqiB65W5KqrrtJ3332n7Oxs9ejRQ0OGDFFOTo6kyICiF1xwwR6dvtqTsFlYWKgffvhBUuSaqH/961866qijZGY6++yzNX36dEllA+I555wTm//1119r2bJlatOmjTZu3Kht27bF5r///vvKysqSmalXr1564YUXJElPPfWU+vVLvMEWAIC9EyPm7+OmTJmiESNGaOfOnbrkkkt0yy23aPTo0crLy4sFsttuu00//vhjmSN6b775pkaOHCkzk7tr2LBhGjp0qKTI9WMXXnihNm3apLS0ND355JNq0aKF3F0jR47U1KlTVa9ePd1yyy0aMGCAPvjgA11xxRXab7/9VFJSohEjRsTGyVq5cqUGDBigb775Rl27dtWzzz4buyEAAIBUV9mI+YQwAECtazXq1bBL2C2r7z5z143i7Cv9RMUqC2H1a7sYANhd/IMGpCb+39w9hDCgDmAHWLfw+wT2DWEP1goAALBPIoQBAACEgBAGAAAQAkIYAABACAhhAAAAISCEAQAAhIAQBgAAEALGCUOdx5hLAIBUxJEwAACAEHAkLIm99ciJxNETAAD2FoSwfdjeGjYJmgCAuoDTkQAAACEghAEAAISAEAYAABACQhgAAEAICGEAAAAhIIQBAACEgBAGAAAQgkBDmJn1NrOlZrbCzEYlWf6Amc2L/iwzs01B1gMAAJAqAhus1czqSRor6VRJhZJmmVmBuy8qbePu18S1v1pS16DqAQAASCVBHgk7WtIKd1/p7tslTZTUr5L2AyVNCLAeAACAlBFkCEuXtDZuujA6rxwzaymptaRpFSwfamazzWx2UVFRjRcKAABQ21LlwvwBkl5w953JFrr7Y+6e5+55aWlptVwaAABAzQsyhK2T1DxuOiM6L5kB4lQkAADYhwQZwmZJyjSz1mbWQJGgVZDYyMw6SGosaUaAtQAAAKSUwEKYuxdLGibpdUmLJU1294VmdoeZ9Y1rOkDSRHf3oGoBAABINYENUSFJ7j5F0pSEeaMTpm8LsgYAAIBUlCoX5gMAAOxTCGEAAAAhIIQBAACEgBAGAAAQAkIYAABACAhhAAAAISCEAQAAhIAQBgAAEAJCGAAAQAgIYQAAACEghAEAAISAEAYAABACQhgAAEAICGEAAAAhIIQBAACEgBAGAAAQAkIYAABACAhhAAAAISCEAQAAhIAQBgAAEAJCGAAAQAgCDWFm1tvMlprZCjMbVUGb881skZktNLPngqwHAAAgVdQPasVmVk/SWEmnSiqUNMvMCtx9UVybTEk3STre3Tea2ZFB1QMAAJBKgjwSdrSkFe6+0t23S5ooqV9Cm8sljXX3jZLk7l8FWA8AAEDKCDKEpUtaGzddGJ0Xr72k9mb2vpl9aGa9A6wHAAAgZQR2OrIa28+UdJKkDEnvmllnd98U38jMhkoaKkktWrSo7RoBAABqXJBHwtZJah43nRGdF69QUoG773D3VZKWKRLKynD3x9w9z93z0tLSAisYAACgtgQZwmZJyjSz1mbWQNIASQUJbV5S5CiYzOwIRU5PrgywJgAAgJQQWAhz92JJwyS9LmmxpMnuvtDM7jCzvtFmr0vaYGaLJL0t6Xp33xBUTQAAAKki0GvC3H2KpCkJ80bHvXZJ10Z/AAAA9hmMmA8AABACQhgAAEAICGEAAAAhIIQBAACEgBAGAAAQAkIYAABACAhhAAAAISCEAQAAhIAQBgAAEAJCGAAAQAgIYQAAACEghAEAAISAEAYAABAC
QhgAAEAICGEAAAAhIIQBAACEgBAGAAAQAkIYAABACAhhAAAAISCEAQAAhIAQBgAAEIJAQ5iZ9TazpWa2wsxGJVl+sZkVmdm86M9lQdYDAACQKuoHtWIzqydprKRTJRVKmmVmBe6+KKHpJHcfFlQdAAAAqSjII2FHS1rh7ivdfbukiZL6Bbg9AACAvUaQISxd0tq46cLovES/MrMFZvaCmTUPsB4AAICUEfaF+f+Q1MrdcyS9KempZI3MbKiZzTaz2UVFRbVaIAAAQBCCDGHrJMUf2cqIzotx9w3uvi06+WdJ3ZOtyN0fc/c8d89LS0sLpFgAAIDaFGQImyUp08xam1kDSQMkFcQ3MLOmcZN9JS0OsB4AAICUEdjdke5ebGbDJL0uqZ6kJ9x9oZndIWm2uxdI+q2Z9ZVULOkbSRcHVQ8AAEAqCSyESZK7T5E0JWHe6LjXN0m6KcgaAAAAUlHYF+YDAADskwhhAAAAIagwhJnZ6WaWn2R+vpmdGmxZAAAAdVtlR8JGS3onyfzpku4IpBoAAIB9RGUhrKG7lxsZ1d2/lnRQcCUBAADUfZWFsEPMrNzdk2a2v6SfBFcSAABA3VdZCPubpMfNLHbUy8waSXokugwAAAC7qbIQ9t+SvpT0uZnNMbOPJa2SVBRdBgAAgN1U4WCt7l4saZSZ3S6pXXT2Cnf/oVYqAwAAqMMqDGFm9suEWS7pMDOb5+7fBlsWAABA3VbZY4vOTjLvcEk5Znapu08LqCYAAIA6r7LTkUOSzTezlpImSzomqKIAAADqumo/tsjdP5e0fwC1AAAA7DOqHcLMrIOkbQHUAgAAsM+o7ML8fyhyMX68wyU1lfTrIIsCAACo6yq7MP+PCdMu6RtFgtivJc0IqigAAIC6rrIL82MP7zazrpIGSTpPkQFbXwy+NAAAgLqrstOR7SUNjP58LWmSJHP3XrVUGwAAQJ1V2enIJZLek3SWu6+QJDO7plaqAgAAqOMquzvyl5K+kPS2mT1uZqdIstopCwAAoG6rMIS5+0vuPkBSB0lvSxoh6UgzG29mp9VWgQAAAHXRLscJc/et7v6cu58tKUPSXEk3VmXlZtbbzJaa2QozG1VJu1+ZmZtZXpUrBwAA2ItVa7BWd9/o7qZG3ZkAABHxSURBVI+5+ym7amtm9SSNldRHUpakgWaWlaTdwZKGS/qoOrUAAADszao9Yn41HC1phbuvdPftkiZK6pek3f9IukfSjwHWAgAAkFKCDGHpktbGTRdG58WYWTdJzd391QDrAAAASDlBhrBKmdl+ku6XNLIKbYea2Wwzm11UVBR8cQAAAAELMoStk9Q8bjojOq/UwZI6SZpuZqsl9ZRUkOzi/Oh1aHnunpeWlhZgyQAAALUjyBA2S1KmmbU2swaSBkgqKF3o7pvd/Qh3b+XurSR9KKmvu88OsCYAAICUEFgIc/diScMkvS5psaTJ7r7QzO4ws75BbRcAAGBvUNlji/aYu0+RNCVh3ugK2p4UZC0AAACpJLQL8wEAAPZlhDAAAIAQEMIAAABCQAgDAAAIASEMAAAgBIQwAACAEBDCAAAAQkAIAwAACAEhDAAAIASEMAAAgBAQwgAAAEJACAMAAAgBIQwAACAEhDAAAIAQEMIAAABCQAgDAAAIASEMAAAgBIQwAACAEBDCAAAAQkAIAwAACAEhDAAAIASBhjAz621mS81shZmNSrL8/5nZJ2Y2z8z+ZWZZQdYDAACQKgILYWZWT9JYSX0kZUkamCRkPefund09V9L/Sro/qHoAAABSSZBHwo6WtMLdV7r7dkkTJfWLb+DuW+ImD5LkAdYDAACQMuoHuO50SWvjpgslHZPYyMyuknStpAaSTg6wHgAAgJQR+oX57j7W3dtKulHSfydrY2ZDzWy2mc0uKiqq3QIBAAACEGQIWyepedx0RnReRSZKOifZAnd/zN3z3D0vLS2tBksEAAAIR5AhbJakTDNrbWYNJA2QVBDfwMwy4ybPlLQ8
wHoAAABSRmDXhLl7sZkNk/S6pHqSnnD3hWZ2h6TZ7l4gaZiZ/ULSDkkbJQ0Oqh4AAIBUEuSF+XL3KZKmJMwbHfd6eJDbBwAASFWhX5gPAACwLyKEAQAAhIAQBgAAEAJCGAAAQAgIYQAAACEghAEAAISAEAYAABACQhgAAEAICGEAAAAhIIQBAACEgBAGAAAQAkIYAABACAhhAAAAISCEAQAAhIAQBgAAEAJCGAAAQAgIYQAAACEghAEAAISAEAYAABACQhgAAEAICGEAAAAhIIQBAACEINAQZma9zWypma0ws1FJll9rZovMbIGZvWVmLYOsBwAAIFUEFsLMrJ6ksZL6SMqSNNDMshKazZWU5+45kl6Q9L9B1QMAAJBKgjwSdrSkFe6+0t23S5ooqV98A3d/292/j05+KCkjwHoAAABSRpAhLF3S2rjpwui8ilwq6bVkC8xsqJnNNrPZRUVFNVgiAABAOFLiwnwz+7WkPEn3Jlvu7o+5e56756WlpdVucQAAAAGoH+C610lqHjedEZ1Xhpn9QtItkk50920B1gMAAJAygjwSNktSppm1NrMGkgZIKohvYGZdJT0qqa+7fxVgLQAAACklsBDm7sWShkl6XdJiSZPdfaGZ3WFmfaPN7pXUSNLzZjbPzAoqWB0AAECdEuTpSLn7FElTEuaNjnv9iyC3DwAAkKpS4sJ8AACAfQ0hDAAAIASEMAAAgBAQwgAAAEJACAMAAAgBIQwAACAEhDAAAIAQEMIAAABCQAgDAAAIASEMAAAgBIQwAACAEBDCAAAAQkAIAwAACAEhDAAAIASEMAAAgBAQwgAAAEJACAMAAAgBIQwAACAEhDAAAIAQEMIAAABCQAgDAAAIQaAhzMx6m9lSM1thZqOSLP+5mX1sZsVmlh9kLQAAAKkksBBmZvUkjZXUR1KWpIFmlpXQbI2kiyU9F1QdAAAAqah+gOs+WtIKd18pSWY2UVI/SYtKG7j76uiykgDrAAAASDlBno5Ml7Q2browOg8AAGCft1dcmG9mQ81stpnNLioqCrscAACAPRZkCFsnqXncdEZ0XrW5+2PunufueWlpaTVSHAAAQJiCDGGzJGWaWWszayBpgKSCALcHAACw1wgshLl7saRhkl6XtFjSZHdfaGZ3mFlfSTKzHmZWKOk8SY+a2cKg6gEAAEglQd4dKXefImlKwrzRca9nKXKaEgAAYJ+yV1yYDwAAUNcQwgAAAEJACAMAAAgBIQwAACAEhDAAAIAQEMIAAABCQAgDAAAIASEMAAAgBIQwAACAEBDCAAAAQkAIAwAACAEhDAAAIASEMAAAgBAQwgAAAEJACAMAAAgBIQwAACAEhDAAAIAQEMIAAABCQAgDAAAIASEMAAAgBIQwAACAEAQawsyst5ktNbMVZjYqyfKGZjYpuvwjM2sVZD0AAACpIrAQZmb1JI2V1EdSlqSBZpaV0OxSSRvdvZ2kByTdE1Q9AAAAqSTII2FHS1rh7ivdfbukiZL6JbTpJ+mp6OsXJJ1iZhZgTQAAACkhyBCWLmlt3HRhdF7SNu5eLGmzpCYB1gQAAJASzN2DWbFZvqTe7n5ZdPpCSce4+7C4Np9G2xRGpz+Ltvk6YV1DJQ2NTh4laWkgRdeOIyR9vctWez/6WXfsC32U6Gddsi/0UaKfe4uW7p6WbEH9ADe6TlLzuOmM6LxkbQrNrL6kQyVtSFyRuz8m6bGA6qxVZjbb3fPCriNo9LPu2Bf6KNHPumRf6KNEP+uCIE9HzpKUaWatzayBpAGSChLaFEgaHH2dL2maB3VoDgAAIIUEdiTM3YvNbJik1yXVk/SEuy80szskzXb3Akn/J+kZM1sh6RtFghoAAECdF+TpSLn7FElTEuaNjnv9o6TzgqwhBdWJ06pVQD/rjn2hjxL9rEv2hT5K9HOvF9iF+QAAAKgYjy0CAAAIASGshpjZTjObZ2YLzWy+mY00s/2iy04yMzezs+Pav2JmJ0VfTzez2XHL
8sxseoC1PmBmI+KmXzezP8dN32dmN5vZCxW8f7qZVXqnSrTNUjNbYGZLzGyMmR1Wc73YPZX9nmp4O6X9nxf9yY/OdzN7Nq5dfTMrMrNXarqGatS6q7/dV6KvLzazEjPLiXvvpzxuDEB1VHOfUxS3H51nZllm1iq6L706bp1jzOzikLq02whhNecHd89192xJpyryuKZb45YXSrqlkvcfaWZ9giwwzvuSjpOk6B/+EZKy45Yfp8idqvl7uJ0L3D1HUo6kbZJe3sP11YRd/Z5q0gXRbeW6e2mg3Sqpk5n9JDp9qsoP3VLbqvOZ7OrvOHB1cQeeSn0ys/9nZhfVWOfKrz9l+lrT6nLfalh19jmT4vajue6+KDr/K0nDLTL6wl6LEBYAd/9KkcFlh5nFHsM0X9JmMzu1grfdq9r7x+0DScdGX2dL+lTSt2bW2MwaSuoo6RuLDKYrM/uJmU00s8Vm9ndJpQFCZnaamc0ws4/N7Hkza5S4sehjq26Q1MLMukTf92szmxnd+TxqkWeNVrg+M1ttZv9rZp9E39duTz+ExN+TmdUzs3vNbJZFjuBdEdfP6+Pm3x6d18oiR/n+Gv1sXjCzA6uw6SmSzoy+Hihpwp72paZU8Lcb7xVJ2WZ2VO1WVkZK7sAtMtbh7kqZPrn7I+7+dFXb70a/U6avVVHN/u1VfatM6T45aFXY51SkSNJb+s8wV3slQlhA3H2lIkNzHBk3+/eS/ruCt8yQtN3MetVCbeslFZtZC0WOes2Q9JEiwSxP0ieStse95TeSvnf3jorsULpLkpkdoUh/fuHu3STNlnRtBdvcqUgQ7WBmHSX1l3S8u+dK2inpgiqsb7O7d5Y0RtKf9viDULnf06XRbfSQ1EPS5RYZ5+40SZmKPA81V1J3M/t5dBVHSRoX/Wy2SLoybvV/jfuWG/84romSBpjZAYocJfyoJvpSUyr42y1VIul/Jd1cq0VVoLZ24GbWIxrA50WDeukXlIvNrMDMpkl6K/5oR3R5tY9i1GKfLo9+sZhvZi+WfoEws9vM7Lro60D7nQJ9bWtmH0a/3N1pZt9F559kZu+ZWYGkRdEvXJ/Gre86M7stxfv2F4teBhGdLu3bfmY2LvoF8k0zm2L/uVxitZndY2YfSzrP4i49MbMjzGx1NfpRZbvY5/S3skcMfxK37B5J19VWYAwCIawWufu7kmRmJ1TQ5E5VHNJq2geKBLDSEDYjbvr9hLY/l/SsJLn7AkkLovN7SsqS9L6ZzVNkp9Gykm2W7ohOUSTIzYq+7xRJbaqwvglx/z1WNe80SRdFt/2RIs8xzYzOP03SXEkfS+oQnS9Ja9299PN6VlL87zb+dGTsSRDRz7CVIkfBygzhspd4TlJPM2sddiFSre3An5R0RdyXhnjdJOW7+4m7U38ytdSnv7l7D3fvImmxIl9CEgXe75D7+qCkB6Nf7goT3tNN0nB3b1+tDsVJod9jvF8qsv/JknShyu9LN7h7N3efWIVt14bEI4Y/lC6Ifr4fSRoUXnl7JtBxwvZlZtZGkZ3WV4qc3itVejSsOPE97j7NzO5UJIwErfS6sM6KnI5cK2mkIkdznqziOkzSm+4+cJcNIzuTzorsJI6U9JS735TQ5uxdrM8reL3bEn5PJulqd389oc3pku5y90cT5rdKUkdV6yqQ9EdJJynFHlpfyd+upNhAzPdJurG2a9sNk+KfVytJpQcl3H2lme1yB26RG0oOdvcZ0VnPSTorrsmb7v5NzZW8S3vcp6hO0f3NYZIaKTKwdvw6U6HfQff1WEnnRF8/p8j/k6VmuvuqPah9V2rl95jECZKed/cSSf82s7cT66pK8TVpV/ucXfiDpBckvVPTddUGjoQFwMzSJD0iaUziY5jc/Q1JjRU5DZXMnYpcPxW0DxTZoX7j7jujO9PDFNkpfZDQ9l1FdwZm1kn/qf1DScdb9PosMzvIzMp9azSz/SXdpchRowWKHGrPN7Mjo8sPN7OWVVhf/7j/
ztAeSvJ7el3Sb6L1yszam9lB0fmX2H+uT0svrV2R69xKv0kOkvSvKm7+CUm3u/sne9qPmlTZ326Cv0j6haSkD6WtTQk78Or6gyJhsjqnixJtjXtdrLL71QN2Z4W11Ke/SBoWPQp0u6pfa430O4X7usf9C7lvsZotcnNAVa8xq6jfu/W3vCvV2Ock5e5LJC2SdPau2qYiQljN+Un0kPJCSf+U9IYi/0Mk83uVfbh5jEeeMlAUTIllfKLIXZEfJszb7O6JT6sfL6mRmS2WdIekOdFaiyRdLGmCmS1QJBh1iHvfX6PzP5V0kKR+0fctUuRo4BvR5W9KalqF9TWOzh8u6Zrd7Hdlv6c/K/I/88fR6z8elVQ/GpyfkzTDzD5R5FvXwdH3LJV0VfSzaRz9rHbJ3Qvd/aHd7ENNq87frqTYzRYPKflpllpTGztwd9+kyI0rx0RnVfZ4tc8lZZlZw+iRpFOqW1Mt/qN0sKQvol86LkiynsD7nQJ9/VDSr6KvK+vfl4rcwd7EIjcvnVVJW0kp0bfVil6/K6mvpP2jr9+X9KvotWE/VeRofEXi17Gnd8vHq84+J/G07XFJ2vxeUkYN1ldrOB1ZQ9y9wnP37j5d0vS46QLFfbtx95MS2ndXwKIXyh+SMO/iuNerJXWKvv5BFeyg3H2aIhexJ84/aRfbn6Qkh70rWl/Uve6+R6fAdvF7KlHkgvNyF527+4OKXD8SEz0dWezuv07S/qQKtpHs7tHpivv7qG1V/dt1978o8q27dNlDigSx2vaT6HV7+yvyTf0ZSfdX0LZ/wjWYV0pan9Dm94pc71eZSyU9bmYlipz22JyskbuvNbPJinzxWFWF9ZYKo0+/U+R6mqLofw9O0iaIfqdSX0dIetbMbpE0VRX3b4dFnns8U5EhZZZUsJ1U6tvjkl42s/mK9K30CNeLioTkRYpchvKxKui3IqdnJ5vZUEmv7qKOKtvdfU6CTnHvma+99KASjy3CXsEid+XkJTlKF5poCHvF3Tvtoin2cmbWyN1L7y4bpciR2+EhlxW4ut5vi9xJ+IO7u5kNkDTQ3fuFXVfQSn+vFrlre6Yid6r/O+y69kUcCcNewd1bhV1DovijhajzzjSzmxTZZ36uyGnzfUFd73d3SWMsckX8JkmXhFxPbXkletq4gaT/IYCFhyNhABBlZmMlHZ8w+0F3r+odwymnLvapInW5r3W5b/syQhgAAEAI9soL2QAAAPZ2hDAAAIAQEMIAAABCQAgDAAAIASEMAAAgBP8fbH9wyiXGbJ4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 720x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Bar chart of the AUC values stored in `scores`, one bar per key.\n",
    "_, ax = plt.subplots(figsize=(10, 5))\n",
    "rect = ax.bar(list(scores.keys()), list(scores.values()))\n",
    "ax.set_ylabel('AUC')\n",
    "# Pin only the upper y-limit at 0.9; the lower limit is left to matplotlib.\n",
    "ax.set_ylim(top=0.9)\n",
    "# `autolabel` is defined in an earlier cell; presumably it annotates each bar -- confirm there.\n",
    "autolabel(rect)\n",
    "# Render the figure. The previous `plt.plot()` had no arguments, so it drew\n",
    "# nothing and only left an empty list as the cell's output.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
